Skip to content

Conversation

@hanhanW
Copy link
Contributor

@hanhanW hanhanW commented Sep 5, 2025

The revision adds a pattern that flattens 2 or more dimensional vector.to_elements ops by vector.shape_cast + vector.to_elements.

It also adds the lowering pattern to ConvertVectorToLLVMPass and complete the tests.

It recovers the e2e lowering breakage from b4c31dc on LLVM path.

@llvmbot
Copy link
Member

llvmbot commented Sep 5, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Han-Chung Wang (hanhanW)

Changes

The revision adds a pattern that flattens 2 or more dimensional vector.to_elements ops by vector.shape_cast + vector.to_elements.

It also adds the lowering pattern to ConvertVectorToLLVMPass and complete the tests.

It recovers the e2e lowering breakage from b4c31dc on LLVM path.


Full diff: https://github.com/llvm/llvm-project/pull/156992.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h (+6)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+1)
  • (modified) mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp (+52)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir (+40)
  • (added) mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir (+22)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+24)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
index 47f96112a9433..e0f744841db2b 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h
@@ -311,6 +311,12 @@ void populateVectorToFromElementsToShuffleTreePatterns(
 void populateVectorFromElementsLoweringPatterns(RewritePatternSet &patterns,
                                                 PatternBenefit benefit = 1);
 
+/// Populate the pattern set with the following patterns:
+///
+/// [FlattenToElements]
+void populateVectorToElementsLoweringPatterns(RewritePatternSet &patterns,
+                                              PatternBenefit benefit = 1);
+
 /// Populate the pattern set with the following patterns:
 ///
 /// [ContractionOpToMatmulOpLowering]
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 9852df6970fdc..0b44ca7ceee42 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -95,6 +95,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorRankReducingFMAPattern(patterns);
     populateVectorGatherLoweringPatterns(patterns);
     populateVectorFromElementsLoweringPatterns(patterns);
+    populateVectorToElementsLoweringPatterns(patterns);
     if (armI8MM) {
       if (armNeon)
         arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index acbf2b746037b..d74007f13a95b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
   LowerVectorScan.cpp
   LowerVectorShapeCast.cpp
   LowerVectorStep.cpp
+  LowerVectorToElements.cpp
   LowerVectorToFromElementsToShuffleTree.cpp
   LowerVectorTransfer.cpp
   LowerVectorTranspose.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
new file mode 100644
index 0000000000000..014034b8f9737
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToElements.cpp
@@ -0,0 +1,52 @@
+//===- LowerVectorToElements.cpp - Lower 'vector.to_elements' op ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.to_elements' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+
+#define DEBUG_TYPE "lower-vector-to-elements"
+
+using namespace mlir;
+
+namespace {
+
+/// Flattens 2 or more dimensional `vector.to_elements` ops by
+/// `vector.shape_cast` + `vector.to_elements`.
+struct FlattenToElements : OpRewritePattern<vector::ToElementsOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ToElementsOp op,
+                                PatternRewriter &rewriter) const override {
+    VectorType vecType = op.getSource().getType();
+    if (vecType.getRank() <= 1)
+      return rewriter.notifyMatchFailure(
+          op, "the rank is already less than or equal to 1");
+    if (vecType.getNumScalableDims() > 0)
+      return rewriter.notifyMatchFailure(
+          op, "scalable vector is not yet supported");
+    auto vec1DType =
+        VectorType::get({vecType.getNumElements()}, vecType.getElementType());
+    Value shapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
+                                                  vec1DType, op.getSource());
+    rewriter.replaceOpWithNewOp<vector::ToElementsOp>(op, op.getResultTypes(),
+                                                      shapeCast);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorToElementsLoweringPatterns(
+    RewritePatternSet &patterns, PatternBenefit benefit) {
+  patterns.add<FlattenToElements>(patterns.getContext(), benefit);
+}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index 07d335117de01..bf4b05f7874de 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -1774,3 +1774,43 @@ func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> v
   %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
   return %0 : vector<2x1x2xf32>
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// vector.to_elements
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @to_elements_1d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2xf32>
+// CHECK:         %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK:         %[[V0:.+]] = llvm.extractelement %[[ARG0]][%[[C0]] : i64] : vector<2xf32>
+// CHECK:         %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK:         %[[V1:.+]] = llvm.extractelement %[[ARG0]][%[[C1]] : i64] : vector<2xf32>
+// CHECK:         return %[[V0]], %[[V1]]
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+  %0:2 = vector.to_elements %arg0 : vector<2xf32>
+  return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// NOTE: We flatten multi-dimensional to_elements ops with pattern
+// `FlattenToElements` and then convert the 1-D to_elements ops to llvm.
+
+// CHECK-LABEL: func @to_elements_2d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK:         %[[CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>>
+// CHECK:         %[[C0:.+]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK:         %[[V0:.+]] = llvm.extractelement %{{.+}}[%[[C0]] : i64] : vector<4xf32>
+// CHECK:         %[[C1:.+]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK:         %[[V1:.+]] = llvm.extractelement %{{.+}}[%[[C1]] : i64] : vector<4xf32>
+// CHECK:         %[[C2:.+]] = llvm.mlir.constant(2 : i64) : i64
+// CHECK:         %[[V2:.+]] = llvm.extractelement %{{.+}}[%[[C2]] : i64] : vector<4xf32>
+// CHECK:         %[[C3:.+]] = llvm.mlir.constant(3 : i64) : i64
+// CHECK:         %[[V3:.+]] = llvm.extractelement %{{.+}}[%[[C3]] : i64] : vector<4xf32>
+// CHECK:         return %[[V0]], %[[V1]], %[[V2]], %[[V3]]
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+  %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
new file mode 100644
index 0000000000000..a57521c4db467
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-to-elements-lowering.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s -test-flatten-vector-to-elements -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @to_elements_1d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2xf32>
+// CHECK:         %[[RES:.+]]:2 = vector.to_elements %[[ARG0]] : vector<2xf32>
+// CHECK:         return %[[RES]]#0, %[[RES]]#1
+func.func @to_elements_1d(%arg0: vector<2xf32>) -> (f32, f32) {
+  %0:2 = vector.to_elements %arg0 : vector<2xf32>
+  return %0#0, %0#1 : f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @to_elements_2d(
+// CHECK-SAME:    %[[ARG0:.+]]: vector<2x2xf32>
+// CHECK:         %[[CAST:.+]] = vector.shape_cast %[[ARG0]]
+// CHECK:         %[[RES:.+]]:4 = vector.to_elements %[[CAST]] : vector<4xf32>
+// CHECK:         return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3
+func.func @to_elements_2d(%arg0: vector<2x2xf32>) -> (f32, f32, f32, f32) {
+  %0:4 = vector.to_elements %arg0 : vector<2x2xf32>
+  return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index bb1598ee3efe5..560a1331bdaf0 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -808,6 +808,28 @@ struct TestUnrollVectorFromElements
   }
 };
 
+struct TestFlattenVectorToElements
+    : public PassWrapper<TestFlattenVectorToElements,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFlattenVectorToElements)
+
+  StringRef getArgument() const final {
+    return "test-flatten-vector-to-elements";
+  }
+  StringRef getDescription() const final {
+    return "Test flattening patterns for to_elements ops";
+  }
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<func::FuncDialect, vector::VectorDialect>();
+  }
+
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorToElementsLoweringPatterns(patterns);
+    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
+  }
+};
+
 struct TestFoldArithExtensionIntoVectorContractPatterns
     : public PassWrapper<TestFoldArithExtensionIntoVectorContractPatterns,
                          OperationPass<func::FuncOp>> {
@@ -1083,6 +1105,8 @@ void registerTestVectorLowerings() {
 
   PassRegistration<TestUnrollVectorFromElements>();
 
+  PassRegistration<TestFlattenVectorToElements>();
+
   PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
 
   PassRegistration<TestVectorEmulateMaskedLoadStore>();

The revision adds a pattern that flattens 2 or more dimensional
`vector.to_elements` ops by `vector.shape_cast` + `vector.to_elements`.

It also adds the lowering pattern to ConvertVectorToLLVMPass and
complete the tests.

It recovers the e2e lowering breakage from llvm@b4c31dc on LLVM path.

Signed-off-by: hanhanW <[email protected]>
@hanhanW
Copy link
Contributor Author

hanhanW commented Sep 5, 2025

cc @yangtetris (I can't add you to reviews)


/// Flattens 2 or more dimensional `vector.to_elements` ops by
/// `vector.shape_cast` + `vector.to_elements`.
struct FlattenToElements : OpRewritePattern<vector::ToElementsOp> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
struct FlattenToElements : OpRewritePattern<vector::ToElementsOp> {
struct FlattenToElements final : OpRewritePattern<vector::ToElementsOp> {

Comment on lines +34 to +36
if (vecType.getNumScalableDims() > 0)
return rewriter.notifyMatchFailure(
op, "scalable vector is not yet supported");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does from_elements support scalable vectors at all? https://mlir.llvm.org/docs/Dialects/Vector/#results-11

I think we can make it an assertion

void runOnOperation() override {
RewritePatternSet patterns(&getContext());
populateVectorToElementsLoweringPatterns(patterns);
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can use walkAndApplyPatterns here since we never have to revisit newly created ops

hanhanW added a commit to iree-org/iree that referenced this pull request Sep 5, 2025
It carries a cherry-pick fix that gets the operands from the adaptor:
-
iree-org/llvm-project@8b88014

Changes:
- Update most lit tests to check `vector.from_elements`.
- Add unrolling patterns to the final conversion.
- Implement n-D `vector::ToElementsOp` lowering, which will be dropped
after llvm/llvm-project#156992 is landed. It
should be added to all the backends, but somehow only AMDGPU backend
needs the pattern. The other backends may address the issue via
specialized tiling config + dropping vector unit dim patterns.

---------

Signed-off-by: hanhanW <[email protected]>
@yangtetris
Copy link
Contributor

Thanks for the fix! To be honest, I didn't realize that this canonicalization pattern also broke vector.to_elements...

Should we also add populateVectorToElementsLoweringPatterns to GpuToLLVMConversionPass and LowerGpuOpsToNVVMOpsPass just like #154774?

Copy link
Member

@Groverkss Groverkss left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be implemented the same way #151175 is implemented. We should be implementing a unrolling pattern for this. The rest of the ops do unrolling by default and we should keep it consistent.

@hanhanW
Copy link
Contributor Author

hanhanW commented Sep 5, 2025

I think flattening and unrolling are different approach, and people can make their own decisions. There may be cases that flattening innermost dims and unrolling the rest, but I don't have a use case so far. We will need unrolling for sure, and @amd-eochoalo will work on it, so I'll leave it to him.

@hanhanW hanhanW closed this Sep 5, 2025
@dcaballe
Copy link
Contributor

dcaballe commented Sep 5, 2025

Yes, we need both approaches and none of them are implemented for this op. Please, let me know if you plan to work on both or just one so that we plan accordingly.

@amd-eochoalo
Copy link
Contributor

@dcaballe I am building on top of this PR. I thought both patterns could be merged. I will be opening it up for review in a couple of minutes.

@amd-eochoalo
Copy link
Contributor

#157142

hhkit pushed a commit to opencompl/iree that referenced this pull request Sep 11, 2025
It carries a cherry-pick fix that gets the operands from the adaptor:
-
iree-org/llvm-project@8b88014

Changes:
- Update most lit tests to check `vector.from_elements`.
- Add unrolling patterns to the final conversion.
- Implement n-D `vector::ToElementsOp` lowering, which will be dropped
after llvm/llvm-project#156992 is landed. It
should be added to all the backends, but somehow only AMDGPU backend
needs the pattern. The other backends may address the issue via
specialized tiling config + dropping vector unit dim patterns.

---------

Signed-off-by: hanhanW <[email protected]>
Signed-off-by: Ivan Ho <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants